{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t09eeeR5prIJ"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "GCCk8_dHpuNf"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ovpZyIhNIgoq"
      },
      "source": [
        "# RNN によるテキスト生成"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hcD2nPQvPOFM"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/text/text_generation\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/tutorials/text/text_generation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/tutorials/text/text_generation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/tutorials/text/text_generation.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RoRi17kuCYFs"
      },
      "source": [
        "Note: これらのドキュメントは私たちTensorFlowコミュニティが翻訳したものです。コミュニティによる 翻訳は**ベストエフォート**であるため、この翻訳が正確であることや[英語の公式ドキュメント](https://www.tensorflow.org/?hl=en)の 最新の状態を反映したものであることを保証することはできません。 この翻訳の品質を向上させるためのご意見をお持ちの方は、GitHubリポジトリ[tensorflow/docs](https://github.com/tensorflow/docs)にプルリクエストをお送りください。 コミュニティによる翻訳やレビューに参加していただける方は、 [docs-ja@tensorflow.org メーリングリスト](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-ja)にご連絡ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BwpJ5IffzRG6"
      },
      "source": [
        "このチュートリアルでは、文字ベースの RNN を使ってテキストを生成する方法を示します。ここでは、Andrej Karpathy の [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) からのシェイクスピア作品のデータセットを使います。このデータからの文字列（\"Shakespear\"）を入力にして、文字列中の次の文字（\"e\"）を予測するモデルを訓練します。このモデルを繰り返し呼び出すことで、より長い文字列を生成することができます。\n",
        "\n",
        "Note: このノートブックの実行を速くするために GPU による高速化を有効にしてください。Colab では、*ランタイム　＞　ランタイムのタイプを変更 ＞ ハードウェアアクセラレータ ＞ GPU* を選択します。ローカルで実行する場合には、TensorFlow のバージョンが 1.11 以降であることを確認してください。\n",
        "\n",
        "このチュートリアルには、[tf.keras](https://www.tensorflow.org/programmers_guide/keras) と [eager execution](https://www.tensorflow.org/programmers_guide/eager) を使ったコードが含まれています。下記は、このチュートリアルのモデルを 30 エポック訓練したものに対して、文字列 \"Q\" を初期値とした場合の出力例です。\n",
        "\n",
        "<pre>\n",
        "QUEENE:\n",
        "I had thought thou hadst a Roman; for the oracle,\n",
        "Thus by All bids the man against the word,\n",
        "Which are so weak of care, by old care done;\n",
        "Your children were in your holy love,\n",
        "And the precipitation through the bleeding throne.\n",
        "\n",
        "BISHOP OF ELY:\n",
        "Marry, and will, my lord, to weep in such a one were prettiest;\n",
        "Yet now I was adopted heir\n",
        "Of the world's lamentable day,\n",
        "To watch the next way with his father with his face?\n",
        "\n",
        "ESCALUS:\n",
        "The cause why then we are all resolved more sons.\n",
        "\n",
        "VOLUMNIA:\n",
        "O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead,\n",
        "And love and pale as any will to that word.\n",
        "\n",
        "QUEEN ELIZABETH:\n",
        "But how long have I heard the soul for this world,\n",
        "And show his hands of life be proved to stand.\n",
        "\n",
        "PETRUCHIO:\n",
        "I say he look'd on, if I must be content\n",
        "To stay him from the fatal of our country's bliss.\n",
        "His lordship pluck'd from this sentence then for prey,\n",
        "And then let us twain, being the moon,\n",
        "were she such a case as fills m\n",
        "</pre>\n",
        "\n",
        "いくつかは文法にあったものがある一方で、ほとんどは意味をなしていません。このモデルは、単語の意味を学習していませんが、次のことを考えてみてください。\n",
        "\n",
        "* このモデルは文字ベースです。訓練が始まった時に、モデルは英語の単語のスペルも知りませんし、単語がテキストの単位であることも知らないのです。\n",
        "\n",
        "* 出力の構造は戯曲に似ています。だいたいのばあい、データセットとおなじ大文字で書かれた話し手の名前で始まっています。\n",
        "\n",
        "* 以下に示すように、モデルはテキストの小さなバッチ（各100文字）で訓練されていますが、一貫した構造のより長いテキストのシーケンスを生成できます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "srXC6pLGLwS6"
      },
      "source": [
        "## 設定"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WGyKZj3bzf9p"
      },
      "source": [
        "### TensorFlow 等のライブラリインポート"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yG_n40gFzf9s"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import numpy as np\n",
        "import os\n",
        "import time"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EHDoRoc5PKWz"
      },
      "source": [
        "### シェイクスピアデータセットのダウンロード\n",
        "\n",
        "独自のデータで実行するためには下記の行を変更してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pD_55cOxLkAb"
      },
      "outputs": [],
      "source": [
        "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UHjdCjDuSvX_"
      },
      "source": [
        "### データの読み込み\n",
        "\n",
        "まずはテキストをのぞいてみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aavnuByVymwK"
      },
      "outputs": [],
      "source": [
        "# 読み込んだのち、Python 2 との互換性のためにデコード\n",
        "text = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n",
        "# テキストの長さは含まれる文字数\n",
        "print ('Length of text: {} characters'.format(len(text)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Duhg9NrUymwO"
      },
      "outputs": [],
      "source": [
        "# テキストの最初の 250文字を参照\n",
        "print(text[:250])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IlCgQBRVymwR"
      },
      "outputs": [],
      "source": [
        "# ファイル中のユニークな文字の数\n",
        "vocab = sorted(set(text))\n",
        "print ('{} unique characters'.format(len(vocab)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rNnrKn_lL-IJ"
      },
      "source": [
        "## テキストの処理"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LFjSVAlWzf-N"
      },
      "source": [
        "### テキストのベクトル化\n",
        "\n",
        "訓練をする前に、文字列を数値表現に変換する必要があります。2つの参照テーブルを作成します。一つは文字を数字に変換するもの、もう一つは数字を文字に変換するものです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IalZLbvOzf-F"
      },
      "outputs": [],
      "source": [
        "# それぞれの文字からインデックスへの対応表を作成\n",
        "char2idx = {u:i for i, u in enumerate(vocab)}\n",
        "idx2char = np.array(vocab)\n",
        "\n",
        "text_as_int = np.array([char2idx[c] for c in text])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tZfqhkYCymwX"
      },
      "source": [
        "これで、それぞれの文字を整数で表現できました。文字を、0 から`len(unique)` までのインデックスに変換していることに注意してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FYyNlCNXymwY"
      },
      "outputs": [],
      "source": [
        "print('{')\n",
        "for char,_ in zip(char2idx, range(20)):\n",
        "    print('  {:4s}: {:3d},'.format(repr(char), char2idx[char]))\n",
        "print('  ...\\n}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l1VKcQHcymwb"
      },
      "outputs": [],
      "source": [
        "# テキストの最初の 13 文字がどのように整数に変換されるかを見てみる\n",
        "print ('{} ---- characters mapped to int ---- > {}'.format(repr(text[:13]), text_as_int[:13]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bbmsf23Bymwe"
      },
      "source": [
        "### 予測タスク"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wssHQ1oGymwe"
      },
      "source": [
        "ある文字、あるいは文字列が与えられたとき、もっともありそうな次の文字はなにか？これが、モデルを訓練してやらせたいタスクです。モデルへの入力は文字列であり、モデルが出力、つまりそれぞれの時点での次の文字を予測をするようにモデルを訓練します。\n",
        "\n",
        "RNN はすでに見た要素に基づく内部状態を保持しているため、この時点までに計算されたすべての文字を考えると、次の文字は何でしょうか？"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hgsVvVxnymwf"
      },
      "source": [
        "### 訓練用サンプルとターゲットを作成\n",
        "\n",
        "つぎに、テキストをサンプルシーケンスに分割します。それぞれの入力シーケンスは、元のテキストからの `seq_length` 個の文字を含みます。\n",
        "\n",
        "入力シーケンスそれぞれに対して、対応するターゲットは同じ長さのテキストを含みますが、1文字ずつ右にシフトしたものです。\n",
        "\n",
        "そのため、テキストを `seq_length+1` のかたまりに分割します。たとえば、 `seq_length` が 4 で、テキストが \"Hello\" だとします。入力シーケンスは \"Hell\" で、ターゲットシーケンスは \"ello\" となります。\n",
        "\n",
        "これを行うために、最初に `tf.data.Dataset.from_tensor_slices` 関数を使ってテキストベクトルを文字インデックスの連続に変換します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0UHJDA39zf-O"
      },
      "outputs": [],
      "source": [
        "# ひとつの入力としたいシーケンスの文字数としての最大の長さ\n",
        "seq_length = 100\n",
        "examples_per_epoch = len(text)//(seq_length+1)\n",
        "\n",
        "# 訓練用サンプルとターゲットを作る\n",
        "char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)\n",
        "\n",
        "for i in char_dataset.take(5):\n",
        "  print(idx2char[i.numpy()])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-ZSYAcQV8OGP"
      },
      "source": [
        "`batch` メソッドを使うと、個々の文字を求める長さのシーケンスに簡単に変換できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l4hkDU3i7ozi"
      },
      "outputs": [],
      "source": [
        "sequences = char_dataset.batch(seq_length+1, drop_remainder=True)\n",
        "\n",
        "for item in sequences.take(5):\n",
        "  print(repr(''.join(idx2char[item.numpy()])))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UbLcIPBj_mWZ"
      },
      "source": [
        "シーケンスそれぞれに対して、`map` メソッドを使って各バッチに単純な関数を適用することで、複製とシフトを行い、入力テキストとターゲットテキストを生成します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9NGu-FkO_kYU"
      },
      "outputs": [],
      "source": [
        "def split_input_target(chunk):\n",
        "    input_text = chunk[:-1]\n",
        "    target_text = chunk[1:]\n",
        "    return input_text, target_text\n",
        "\n",
        "dataset = sequences.map(split_input_target)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hiCopyGZymwi"
      },
      "source": [
        "最初のサンプルの入力とターゲットを出力します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GNbw-iR0ymwj"
      },
      "outputs": [],
      "source": [
        "for input_example, target_example in  dataset.take(1):\n",
        "  print ('Input data: ', repr(''.join(idx2char[input_example.numpy()])))\n",
        "  print ('Target data:', repr(''.join(idx2char[target_example.numpy()])))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_33OHL3b84i0"
      },
      "source": [
        "これらのベクトルのインデックスそれぞれが一つのタイムステップとして処理されます。タイムステップ 0 の入力として、モデルは \"F\" のインデックスを受け取り、次の文字として \"i\" のインデックスを予測しようとします。次のタイムステップでもおなじことをしますが、`RNN` は現在の入力文字に加えて、過去のステップのコンテキストも考慮します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0eBu9WZG84i0"
      },
      "outputs": [],
      "source": [
        "for i, (input_idx, target_idx) in enumerate(zip(input_example[:5], target_example[:5])):\n",
        "    print(\"Step {:4d}\".format(i))\n",
        "    print(\"  input: {} ({:s})\".format(input_idx, repr(idx2char[input_idx])))\n",
        "    print(\"  expected output: {} ({:s})\".format(target_idx, repr(idx2char[target_idx])))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MJdfPmdqzf-R"
      },
      "source": [
        "### 訓練用バッチの作成\n",
        "\n",
        "`tf.data` を使ってテキストを分割し、扱いやすいシーケンスにします。しかし、このデータをモデルに供給する前に、データをシャッフルしてバッチにまとめる必要があります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p2pGotuNzf-S"
      },
      "outputs": [],
      "source": [
        "# バッチサイズ\n",
        "BATCH_SIZE = 64\n",
        "\n",
        "# データセットをシャッフルするためのバッファサイズ\n",
        "# （TF data は可能性として無限長のシーケンスでも使えるように設計されています。\n",
        "# このため、シーケンス全体をメモリ内でシャッフルしようとはしません。\n",
        "# その代わりに、要素をシャッフルするためのバッファを保持しています）\n",
        "BUFFER_SIZE = 10000\n",
        "\n",
        "dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)\n",
        "\n",
        "dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r6oUuElIMgVx"
      },
      "source": [
        "## モデルの構築"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m8gPwEjRzf-Z"
      },
      "source": [
        "`tf.keras.Sequential` を使ってモデルを定義します。この簡単な例では、モデルの定義に3つのレイヤーを使用しています。\n",
        "\n",
        "* `tf.keras.layers.Embedding`: 入力レイヤー。それぞれの文字を表す数を `embedding_dim`　次元のベクトルに変換する、訓練可能な参照テーブル。\n",
        "* `tf.keras.layers.GRU`: サイズが `units=rnn_units` のRNNの一種（ここに LSTM レイヤーを使うこともできる）。\n",
        "* `tf.keras.layers.Dense`: `vocab_size` の出力を持つ、出力レイヤー。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zHT8cLh7EAsg"
      },
      "outputs": [],
      "source": [
        "# 文字数で表されるボキャブラリーの長さ\n",
        "vocab_size = len(vocab)\n",
        "\n",
        "# 埋め込みベクトルの次元\n",
        "embedding_dim = 256\n",
        "\n",
        "# RNN ユニットの数\n",
        "rnn_units = 1024"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MtCrdfzEI2N0"
      },
      "outputs": [],
      "source": [
        "def build_model(vocab_size, embedding_dim, rnn_units, batch_size):\n",
        "  model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Embedding(vocab_size, embedding_dim,\n",
        "                              batch_input_shape=[batch_size, None]),\n",
        "    tf.keras.layers.GRU(rnn_units,\n",
        "                        return_sequences=True,\n",
        "                        stateful=True,\n",
        "                        recurrent_initializer='glorot_uniform'),\n",
        "    tf.keras.layers.Dense(vocab_size)\n",
        "  ])\n",
        "  return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wwsrpOik5zhv"
      },
      "outputs": [],
      "source": [
        "model = build_model(\n",
        "  vocab_size = len(vocab),\n",
        "  embedding_dim=embedding_dim,\n",
        "  rnn_units=rnn_units,\n",
        "  batch_size=BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RkA5upJIJ7W7"
      },
      "source": [
        "1文字ごとにモデルは埋め込みベクトルを検索し、その埋め込みベクトルを入力として GRU を 1 タイムステップ実行します。そして Dense レイヤーを適用して、次の文字の対数尤度を予測するロジットを生成します。\n",
        "\n",
        "![A drawing of the data passing through the model](https://github.com/masa-ita/tf-docs/blob/site_ja_tutorials_text_text_generation/site/ja/tutorials/text/images/text_generation_training.png?raw=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-ubPo0_9Prjb"
      },
      "source": [
        "## モデルを試す\n",
        "\n",
        "期待通りに動作するかどうかを確認するためモデルを動かしてみましょう。\n",
        "\n",
        "最初に、出力の shape を確認します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C-_70kKAPrPU"
      },
      "outputs": [],
      "source": [
        "for input_example_batch, target_example_batch in dataset.take(1):\n",
        "  example_batch_predictions = model(input_example_batch)\n",
        "  print(example_batch_predictions.shape, \"# (batch_size, sequence_length, vocab_size)\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q6NzLBi4VM4o"
      },
      "source": [
        "上記の例では、入力のシーケンスの長さは `100` ですが、モデルはどのような長さの入力でも実行できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vPGmAAXmVLGC"
      },
      "outputs": [],
      "source": [
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uwv0gEkURfx1"
      },
      "source": [
        "モデルから実際の予測を得るには出力の分布からサンプリングを行う必要があります。この分布は、文字ボキャブラリー全体のロジットで定義されます。\n",
        "\n",
        "Note: この分布から _サンプリング_ するということが重要です。なぜなら、分布の _argmax_ をとったのでは、モデルは簡単にループしてしまうからです。\n",
        "\n",
        "バッチ中の最初のサンプルで試してみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4V4MfFg0RQJg"
      },
      "outputs": [],
      "source": [
        "sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)\n",
        "sampled_indices = tf.squeeze(sampled_indices,axis=-1).numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QM1Vbxs_URw5"
      },
      "source": [
        "これにより、タイムステップそれぞれにおいて、次の文字のインデックスの予測が得られます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YqFMUQc_UFgM"
      },
      "outputs": [],
      "source": [
        "sampled_indices"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LfLtsP3mUhCG"
      },
      "source": [
        "これらをデコードすることで、この訓練前のモデルによる予測テキストをみることができます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xWcFwPwLSo05"
      },
      "outputs": [],
      "source": [
        "print(\"Input: \\n\", repr(\"\".join(idx2char[input_example_batch[0]])))\n",
        "print()\n",
        "print(\"Next Char Predictions: \\n\", repr(\"\".join(idx2char[sampled_indices ])))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LJL0Q0YPY6Ee"
      },
      "source": [
        "## モデルの訓練"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YCbHQHiaa4Ic"
      },
      "source": [
        "ここまでくれば問題は標準的な分類問題として扱うことができます。これまでの RNN の状態と、いまのタイムステップの入力が与えられ、次の文字のクラスを予測します。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "trpqTWyvk0nr"
      },
      "source": [
        "### オプティマイザと損失関数の付加"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UAjbjY03eiQ4"
      },
      "source": [
        "この場合、標準の `tf.keras.losses.sparse_categorical_crossentropy` 損失関数が使えます。予測の最後の次元に適用されるからです。\n",
        "\n",
        "このモデルはロジットを返すので、`from_logits` フラグをセットする必要があります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4HrXTACTdzY-"
      },
      "outputs": [],
      "source": [
        "def loss(labels, logits):\n",
        "  return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)\n",
        "\n",
        "example_batch_loss  = loss(target_example_batch, example_batch_predictions)\n",
        "print(\"Prediction shape: \", example_batch_predictions.shape, \" # (batch_size, sequence_length, vocab_size)\")\n",
        "print(\"scalar_loss:      \", example_batch_loss.numpy().mean())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jeOXriLcymww"
      },
      "source": [
        "`tf.keras.Model.compile` を使って、訓練手順を定義します。既定の引数を持った `tf.keras.optimizers.Adam` と、先ほどの loss 関数を使用しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DDl1_Een6rL0"
      },
      "outputs": [],
      "source": [
        "model.compile(optimizer='adam', loss=loss)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ieSJdchZggUj"
      },
      "source": [
        "### チェックポイントの構成"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C6XBUUavgF56"
      },
      "source": [
        "`tf.keras.callbacks.ModelCheckpoint` を使って、訓練中にチェックポイントを保存するようにします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W6fWTriUZP-n"
      },
      "outputs": [],
      "source": [
        "# チェックポイントが保存されるディレクトリ\n",
        "checkpoint_dir = './training_checkpoints'\n",
        "# チェックポイントファイルの名称\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")\n",
        "\n",
        "checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(\n",
        "    filepath=checkpoint_prefix,\n",
        "    save_weights_only=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3Ky3F_BhgkTW"
      },
      "source": [
        "### 訓練の実行"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IxdOA-rgyGvs"
      },
      "source": [
        "訓練時間を適切に保つために、10エポックを使用してモデルを訓練します。Google Colab を使用する場合には、訓練を高速化するためにランタイムを GPU に設定します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7yGBE2zxMMHs"
      },
      "outputs": [],
      "source": [
        "EPOCHS=10"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UK-hmKjYVoll"
      },
      "outputs": [],
      "source": [
        "history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kKkD5M6eoSiN"
      },
      "source": [
        "## テキスト生成"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JIPcXllKjkdr"
      },
      "source": [
        "### 最終チェックポイントの復元"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LyeYRiuVjodY"
      },
      "source": [
        "予測ステップを単純にするため、バッチサイズ 1 を使用します。\n",
        "\n",
        "RNN が状態をタイムステップからタイムステップへと渡す仕組みのため、モデルは一度構築されると固定されたバッチサイズしか受け付けられません。\n",
        "\n",
        "モデルを異なる `batch_size` で実行するためには、モデルを再構築し、チェックポイントから重みを復元する必要があります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zk2WJ2-XjkGz"
      },
      "outputs": [],
      "source": [
        "tf.train.latest_checkpoint(checkpoint_dir)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LycQ-ot_jjyu"
      },
      "outputs": [],
      "source": [
        "model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)\n",
        "\n",
        "model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))\n",
        "\n",
        "model.build(tf.TensorShape([1, None]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "71xa6jnYVrAN"
      },
      "outputs": [],
      "source": [
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DjGz1tDkzf-u"
      },
      "source": [
        "### 予測ループ\n",
        "\n",
        "下記のコードブロックでテキストを生成します。\n",
        "\n",
        "* 最初に、開始文字列を選択し、RNN の状態を初期化して、生成する文字数を設定します。\n",
        "\n",
        "* 開始文字列と RNN の状態を使って、次の文字の予測分布を得ます。\n",
        "\n",
        "* つぎに、カテゴリー分布を使用して、予測された文字のインデックスを計算します。この予測された文字をモデルの次の入力にします。\n",
        "\n",
        "* モデルによって返された RNN の状態はモデルにフィードバックされるため、1つの文字だけでなく、より多くのコンテキストを持つことになります。つぎの文字を予測した後、更新された RNN の状態が再びモデルにフィードバックされます。こうしてモデルは以前に予測した文字からさらにコンテキストを得ることで学習するのです。\n",
        "\n",
        "![To generate text the model's output is fed back to the input](https://github.com/masa-ita/tf-docs/blob/site_ja_tutorials_text_text_generation/site/ja/tutorials/text/images/text_generation_sampling.png?raw=1)\n",
        "\n",
        "生成されたテキストを見ると、モデルがどこを大文字にするかや、段落の区切り方、シェークスピアらしい書き言葉を真似ることを知っていることがわかります。しかし、訓練のエポック数が少ないので、まだ一貫した文章を生成するところまでは学習していません。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WvuwZBX5Ogfd"
      },
      "outputs": [],
      "source": [
        "def generate_text(model, start_string):\n",
        "  # 評価ステップ（学習済みモデルを使ったテキスト生成）\n",
        "\n",
        "  # 生成する文字数\n",
        "  num_generate = 1000\n",
        "\n",
        "  # 開始文字列を数値に変換（ベクトル化）\n",
        "  input_eval = [char2idx[s] for s in start_string]\n",
        "  input_eval = tf.expand_dims(input_eval, 0)\n",
        "\n",
        "  # 結果を保存する空文字列\n",
        "  text_generated = []\n",
        "\n",
        "  # 低い temperature　は、より予測しやすいテキストをもたらし\n",
        "  # 高い temperature は、より意外なテキストをもたらす\n",
        "  # 実験により最適な設定を見つけること\n",
        "  temperature = 1.0\n",
        "\n",
        "  # ここではバッチサイズ　== 1\n",
        "  model.reset_states()\n",
        "  for i in range(num_generate):\n",
        "      predictions = model(input_eval)\n",
        "      # バッチの次元を削除\n",
        "      predictions = tf.squeeze(predictions, 0)\n",
        "\n",
        "      # カテゴリー分布をつかってモデルから返された文字を予測 \n",
        "      predictions = predictions / temperature\n",
        "      predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()\n",
        "\n",
        "      # 過去の隠れ状態とともに予測された文字をモデルへのつぎの入力として渡す\n",
        "      input_eval = tf.expand_dims([predicted_id], 0)\n",
        "\n",
        "      text_generated.append(idx2char[predicted_id])\n",
        "\n",
        "  return (start_string + ''.join(text_generated))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ktovv0RFhrkn"
      },
      "outputs": [],
      "source": [
        "print(generate_text(model, start_string=u\"ROMEO: \"))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AM2Uma_-yVIq"
      },
      "source": [
        "この結果を改善するもっとも簡単な方法は、もっと長く訓練することです（`EPOCHS=30` を試してみましょう）。\n",
        "\n",
        "また、異なる初期文字列を使ったり、モデルの精度を向上させるためにもうひとつ RNN レイヤーを加えたり、temperature パラメータを調整して、よりランダム性の強い、あるいは、弱い予測を試してみたりすることができます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y4QwTjAM6A2O"
      },
      "source": [
        "## 上級編： 訓練のカスタマイズ\n",
        "\n",
        "上記の訓練手順は単純ですが、制御できるところがそれほどありません。\n",
        "\n",
        "モデルを手動で実行する方法を見てきたので、訓練ループを展開し、自分で実装してみましょう。このことが、たとえばモデルのオープンループによる出力を安定化するための _カリキュラム学習_ を実装するための出発点になります。\n",
        "\n",
        "勾配を追跡するために `tf.GradientTape` を使用します。このアプローチについての詳細を学ぶには、 [eager execution guide](https://www.tensorflow.org/guide/eager) をお読みください。\n",
        "\n",
        "この手順は下記のように動作します。\n",
        "\n",
        "* 最初に、RNN の状態を初期化する。`tf.keras.Model.reset_states` メソッドを呼び出すことでこれを実行する。\n",
        "\n",
        "* つぎに、（1バッチずつ）データセットを順番に処理し、それぞれのバッチに対する*予測値*を計算する。\n",
        "\n",
        "* `tf.GradientTape` をオープンし、そのコンテキストで、予測値と損失を計算する。\n",
        "\n",
        "* `tf.GradientTape.grads` メソッドを使って、モデルの変数に対する損失の勾配を計算する。\n",
        "\n",
        "* 最後に、オプティマイザの `tf.train.Optimizer.apply_gradients` メソッドを使って、逆方向の処理を行う。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_XAm7eCoKULT"
      },
      "outputs": [],
      "source": [
        "model = build_model(\n",
        "  vocab_size = len(vocab),\n",
        "  embedding_dim=embedding_dim,\n",
        "  rnn_units=rnn_units,\n",
        "  batch_size=BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qUKhnZtMVpoJ"
      },
      "outputs": [],
      "source": [
        "optimizer = tf.keras.optimizers.Adam()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b4kH1o0leVIp"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def train_step(inp, target):\n",
        "  with tf.GradientTape() as tape:\n",
        "    predictions = model(inp)\n",
        "    loss = tf.reduce_mean(\n",
        "        tf.keras.losses.sparse_categorical_crossentropy(\n",
        "            target, predictions, from_logits=True))\n",
        "  grads = tape.gradient(loss, model.trainable_variables)\n",
        "  optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
        "\n",
        "  return loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d4tSNwymzf-q"
      },
      "outputs": [],
      "source": [
        "# 訓練ステップ\n",
        "EPOCHS = 10\n",
        "\n",
        "for epoch in range(EPOCHS):\n",
        "  start = time.time()\n",
        "\n",
        "  # 各エポックの最初に、隠れ状態を初期化する\n",
        "  # 最初は隠れ状態は None\n",
        "  hidden = model.reset_states()\n",
        "\n",
        "  for (batch_n, (inp, target)) in enumerate(dataset):\n",
        "    loss = train_step(inp, target)\n",
        "\n",
        "    if batch_n % 100 == 0:\n",
        "      template = 'Epoch {} Batch {} Loss {}'\n",
        "      print(template.format(epoch+1, batch_n, loss))\n",
        "\n",
        "  # 5エポックごとにモデル（のチェックポイント）を保存する\n",
        "  if (epoch + 1) % 5 == 0:\n",
        "    model.save_weights(checkpoint_prefix.format(epoch=epoch))\n",
        "\n",
        "  print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n",
        "  print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))\n",
        "\n",
        "model.save_weights(checkpoint_prefix.format(epoch=epoch))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "text_generation.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
